# coding:utf8
import tensorflow as tf
from PIL import Image
import numpy as np

model_path = "./isEndModel/model/model.ckpt"  # path of the saved model checkpoint
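# number of output classes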
CLASS = 2

g1=tf.Graph()

# Read and decode examples from a TFRecord file
def read_and_decode(filename):
    with g1.as_default():  
        # create a file queue with no limit on the number of reads
        filename_queue = tf.train.string_input_producer([filename])
        # create a reader from file queue
        reader = tf.TFRecordReader()
        # the reader pulls one serialized example from the file queue
        _, serialized_example = reader.read(filename_queue)
        # parse the serialized example into its features
        features = tf.parse_single_example(
            serialized_example,
            features={
                'label': tf.FixedLenFeature([], tf.int64),
                'img_raw': tf.FixedLenFeature([], tf.string)
            }
        )
        label = features['label']
        label = tf.cast(label, tf.int32)
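        # turn the integer label into a one-hot vector of length CLASS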
        label = tf.one_hot(label,CLASS,1,0) 
        img = features['img_raw']    
        img = tf.decode_raw(img, tf.uint8)
        img = tf.reshape(img, [28,28, 1])
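        # scale pixel values from [0, 255] into [-0.5, 0.5]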
        img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
        
        return img, label

"""
权重初始化
初始化为一个接近0的很小的正数
"""
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev = 0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    initial = tf.constant(0.1, shape = shape)
    return tf.Variable(initial)

"""
卷积和池化，使用卷积步长为1（stride size）,0边距（padding size）
池化用简单传统的2x2大小的模板做max pooling
"""
def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding = 'SAME')

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize = [1, 2, 2, 1],
                          strides = [1, 2, 2, 1], padding = 'SAME')


# Define the tensors and operations in graph g1
with g1.as_default():  
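    # placeholders for 28x28 single-channel input images and one-hot labels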
    x = tf.placeholder("float", shape=[None, 28,28,1])
    y_ = tf.placeholder("float", shape=[None,2])

    """
    第一层 卷积层
    """
    W_conv1 = weight_variable([5, 5, 1, 32])
    b_conv1 = bias_variable([32])

    h_conv1 = tf.nn.relu(conv2d(x, W_conv1) + b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)
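    # 2x2 max pooling halves the spatial size: 28x28x32 -> 14x14x32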

    """
    第二层 卷积层
    """
    W_conv2 = weight_variable([5, 5, 32, 64])
    b_conv2 = bias_variable([64])

    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)
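    # second pooling stage: 14x14x64 -> 7x7x64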

    """
    第三层 全连接层
    """
    W_fc1 = weight_variable([7 * 7 * 64, 1024])
    b_fc1 = bias_variable([1024])

    h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
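    # fully connected layer mapping the flattened features to 1024 hidden units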
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

    """
    Dropout
    """
    keep_prob = tf.placeholder("float")
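    # keep_prob is fed as 0.5 during training and 1.0 at inference time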
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

    """
    第四层 Softmax输出层
    """
    W_fc2 = weight_variable([1024, 2])
    b_fc2 = bias_variable([2])

    y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) 
     
    # loss: mean squared error between the one-hot labels and the softmax outputs
    loss = tf.reduce_mean(tf.square(y_ - y_conv))

    train_step = tf.train.AdamOptimizer(0.0001).minimize(loss)
    correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))  
    pred = tf.argmax(y_conv,1)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 

    saver = tf.train.Saver()
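    # the Saver both writes checkpoints during training and restores them for inference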
    
    
    sess = tf.Session(graph=g1)    
    #sess.run(tf.global_variables_initializer())  # uncomment for the very first run (training from scratch)
    saver.restore(sess, model_path)  # restore the saved model; comment this line out on the first run
                      
# Prediction
def predict(img):
    # crop the region of interest, convert to grayscale and shrink to 28x28
    img = img.crop((100, 1600, 300, 1800))
    img = img.convert('L')
    img = img.resize((28, 28))
    # preprocess with numpy so no new ops are added to g1 on every call,
    # matching the normalization used in read_and_decode
    img = np.asarray(img, dtype=np.float32).reshape(1, 28, 28, 1)
    img = img * (1. / 255) - 0.5
    result = sess.run(pred, feed_dict={x: img, keep_prob: 1.0})
    return result[0]
    
# Training
def train():
    with g1.as_default(), sess.as_default():
        batch_size = 15
        img, label = read_and_decode("./isEndModel/data/train.tfrecords")
        img_batch, label_batch = tf.train.shuffle_batch([img, label],
                                                        batch_size=batch_size, capacity=2000,
                                                        min_after_dequeue=1000)
        # start the input-queue threads that feed the shuffled batches
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        for i in range(5000):  # train for 5000 steps
            train_batch_x, train_batch_y = sess.run([img_batch, label_batch])
            sess.run([train_step], feed_dict={x: train_batch_x, y_: train_batch_y, keep_prob: 0.5})
            if i % 100 == 0:
                p, accuracy2 = sess.run([pred, accuracy],
                                        feed_dict={x: train_batch_x, y_: train_batch_y, keep_prob: 1.0})
                print(train_batch_y)
                print(p)
                print(i, accuracy2)
                save_path = saver.save(sess, model_path)  # save a checkpoint

        coord.request_stop()
        coord.join(threads)
    
    
    
    
if __name__ == '__main__':
    print(predict(Image.open("./isEndModel/data/img/0001-1.png")))
    #train()


